package memdb

import (
	"errors"
	"fmt"
	"math"
	"strings"

	"gitee.com/haodreams/libs/memtable"
	"github.com/alecthomas/participle/v2"
	"github.com/alecthomas/participle/v2/lexer"
)

type (
	// StrCompare reports whether the pair of strings (x, y) satisfies a
	// comparison.
	StrCompare func(x, y string) bool

	// NumMath combines two numbers into a single number.
	NumMath func(x, y float64) float64

	// NumCompare reports whether the pair of numbers (x, y) satisfies a
	// comparison.
	NumCompare func(x, y float64) bool

	// StatisFunc is a statistic function evaluated over column idx of table.
	StatisFunc func(table *memtable.Table, idx int) float64
)

// gStrFunc maps a comparison operator to the predicate Where applies to
// columns that are not numeric. "!=" and "<>" share one implementation.
var gStrFunc = map[string]StrCompare{
	// equality
	"=":  strEQ,
	"!=": strNE,
	"<>": strNE,
	// ordering
	"<":  strLT,
	"<=": strLE,
	">":  strGT,
	">=": strGE,
}
// gNumMath maps an arithmetic operator to the function Where uses for the
// math part of a numeric comparison.
var gNumMath = map[string]NumMath{
	// multiplicative
	"*": numMult,
	"/": numDivi,
	"%": numModu,
	// additive
	"+": numAdd,
	"-": numSubt,
}

// gNumFunc maps a comparison operator to the predicate Where applies to
// numeric columns. "!=" and "<>" share one implementation.
var gNumFunc = map[string]NumCompare{
	// equality
	"=":  numEQ,
	"!=": numNE,
	"<>": numNE,
	// ordering
	"<":  numLT,
	"<=": numLE,
	">":  numGT,
	">=": numGE,
}

// gStatisFunc maps the upper-case name of a statistic function to its
// memtable implementation. Query upper-cases the name found in the statement
// before looking it up here.
var gStatisFunc = map[string]StatisFunc{
	"COUNT": memtable.Count,
	"SUM":   memtable.Sum,
	"AVG":   memtable.Avg,
	"MIN":   memtable.Min,
	"MAX":   memtable.Max,
}

// Database is a set of named in-memory tables together with the parser used
// to query them.
//
// WHERE clauses support comparisons plus a single arithmetic operation, for
// example `where a < b * 100`; the numeric constant must be the last term of
// the expression. Arithmetic on selected fields, such as `select a * 100`, is
// not supported yet.
type Database struct {
	// mapTable holds the registered tables keyed by their registration name.
	mapTable map[string]*memtable.Table
	// sqlparse parses query text into a SelectStmt.
	sqlparse *participle.Parser[SelectStmt]
}

// NewParser builds the participle parser that turns query text into a
// SelectStmt.
//
// Keyword tokens are matched case-insensitively and String tokens are
// unquoted before they reach the grammar. String literals may be delimited by
// single or double quotes. Both lexer.MustSimple and participle.MustBuild
// panic when the lexer rules or the grammar are invalid.
func NewParser() *participle.Parser[SelectStmt] {
	sqlLexer := lexer.MustSimple([]lexer.SimpleRule{
		// Full keyword list of the upstream example, kept for reference:
		//{Name: `Keyword`, Pattern: `(?i)\b(SELECT|FROM|TOP|DISTINCT|ALL|WHERE|GROUP|BY|HAVING|UNION|MINUS|EXCEPT|INTERSECT|ORDER|LIMIT|OFFSET|TRUE|FALSE|NULL|IS|NOT|ANY|SOME|BETWEEN|AND|OR|LIKE|AS|IN|FILL)\b`},
		{Name: `Keyword`, Pattern: `(?i)\b(SELECT|FROM|WHERE|GROUP|BY|HAVING|ORDER|LIMIT|OFFSET|TRUE|FALSE|NULL|IS|NOT|BETWEEN|AND|OR|LIKE|AS|IN|FILL)\b`},
		// "*" is allowed as the first character, so a lone "*" lexes as an Ident.
		{Name: `Ident`, Pattern: `[a-zA-Z_*][a-zA-Z0-9_]*`},
		{Name: `Number`, Pattern: `[-+]?\d*\.?\d+([eE][-+]?\d+)?`},
		// The second alternative used to repeat the single-quote form, which
		// made it dead; it now accepts double-quoted literals.
		{Name: `String`, Pattern: `'[^']*'|"[^"]*"`},
		{Name: `Operators`, Pattern: `<>|!=|<=|>=|[-+*/%,.()=<>]`},
		{Name: "whitespace", Pattern: `\s+`},
	})
	sqlparse := participle.MustBuild[SelectStmt](
		participle.Lexer(sqlLexer),
		participle.Unquote("String"),
		participle.CaseInsensitive("Keyword"),
		// participle.Elide('Comment'),
		// Need to solve left recursion detection first, if possible.
		// participle.UseLookahead(),
	)
	return sqlparse
}

// NewDatabase returns a Database with an empty table registry and a parser
// built by NewParser.
func NewDatabase() *Database {
	return &Database{
		mapTable: make(map[string]*memtable.Table),
		sqlparse: NewParser(),
	}
}

// RegisterTable builds a memtable.Table from table and key and registers it.
//
// The table is registered under names[0] when at least one name is given,
// otherwise under the name reported by the new table itself. An error from
// building or from registering the table is returned as is.
func (m *Database) RegisterTable(table any, key string, names ...string) (err error) {
	built, err := memtable.NewTable(table, key)
	if err != nil {
		return err
	}

	regName := built.Name()
	if len(names) != 0 {
		regName = names[0]
	}

	err = m.RegisterMemTable(built, regName)
	return err
}

// Tables returns every registered table. The order follows map iteration and
// is therefore unspecified.
func (m *Database) Tables() []*memtable.Table {
	all := make([]*memtable.Table, 0, len(m.mapTable))
	for _, registered := range m.mapTable {
		all = append(all, registered)
	}
	return all
}

// RegisterMemTable registers an already built table under name and tells the
// table its registration name.
//
// It returns an error, leaving the registry unchanged, when table is nil or
// when a table is already registered under name. Earlier versions panicked on
// a duplicate name even though the method declares an error result.
func (m *Database) RegisterMemTable(table *memtable.Table, name string) (err error) {
	if table == nil {
		return errors.New("表不能为空")
	}
	if _, exists := m.mapTable[name]; exists {
		return fmt.Errorf("'%s' 表已存在", name)
	}
	m.mapTable[name] = table
	table.SetRegisterName(name)
	return nil
}

// GetTable returns the table registered under name, or nil when no table has
// that name.
func (m *Database) GetTable(name string) *memtable.Table {
	if registered, found := m.mapTable[name]; found {
		return registered
	}
	return nil
}

// Query parses sql, evaluates it against the registered tables and returns
// the result as a temporary table.
//
// Steps, in order:
//  1. every WHERE condition is applied to the table returned by the previous
//     one;
//  2. the remaining table is split by GROUP BY when present;
//  3. the selected fields are copied into the temporary table: one output row
//     per group when any field uses a statistic function, otherwise one
//     output row per input row;
//  4. the optional fill clause and ORDER BY are applied to the temporary
//     table.
//
// A nil table and an error are returned when the statement cannot be parsed,
// when the table, a column or a statistic function is unknown, when "*" is
// selected without a statistic function, or when any step reports an error.
func (m *Database) Query(sql string) (tempTable *memtable.TempTable, err error) {
	stmt, err := m.sqlparse.ParseString("", sql)
	if err != nil {
		return nil, err
	}

	table := m.GetTable(stmt.From)
	if table == nil {
		return nil, fmt.Errorf("'%s' 表不存在", stmt.From)
	}

	// 1. Filter by the WHERE conditions.
	for _, where := range stmt.Where {
		table, err = Where(table, where)
		if err != nil {
			return nil, err
		}
		if table == nil {
			// Every later step dereferences the table, so fail here instead
			// of panicking further down.
			return nil, errors.New("数据异常")
		}
	}

	// 2. Split into groups; without GROUP BY the whole table is one group.
	var tabs []*memtable.Table
	if stmt.GroupBy != nil {
		tabs = table.GroupBy(stmt.GroupBy.Name)
	} else {
		tabs = []*memtable.Table{table}
	}
	if tabs == nil {
		return nil, errors.New("数据异常")
	}

	// 3. Resolve each selected field to an output title, a source column
	// index and an optional statistic function name.
	fieldCount := len(stmt.Fields)
	titles := make([]string, fieldCount)
	titleIdx := make([]int, fieldCount)
	funcs := make([]string, fieldCount)
	isStatis := false
	for i, field := range stmt.Fields {
		parts := len(field.Name)
		if parts == 0 {
			return nil, errors.New("字段名不能为空")
		}
		// The column name is the last part; a preceding part is the name of
		// a statistic function wrapping that column.
		title := field.Name[parts-1]
		if parts > 1 {
			funcs[i] = field.Name[0]
			isStatis = true
		}
		titles[i] = title
		if field.As != "" {
			titles[i] = field.As
		}

		idx := table.ColumnIndex(title)
		if idx < 0 {
			if title != "*" {
				return nil, fmt.Errorf("'%s' 列不存在", title)
			}
			if funcs[i] == "" {
				// A plain "*" has no column to copy from; using index -1 on
				// the row data would panic.
				return nil, fmt.Errorf("'%s' 只能用于统计函数", title)
			}
			idx = -1
		}
		titleIdx[i] = idx
	}

	tempTable = memtable.NewTempTable(titles)
	for _, tab := range tabs {
		rows := tab.Rows()
		if len(rows) == 0 {
			continue
		}

		if isStatis {
			// Statistic query: one output row per group. Plain columns take
			// their value from the first row of the group.
			row := tempTable.Append()
			for i, idx := range titleIdx {
				if funcs[i] == "" {
					if value, ok := derefCell(rows[0].Data[idx]); ok {
						row.Data[i] = value
					}
					continue
				}
				fc := strings.ToUpper(funcs[i])
				f, ok := gStatisFunc[fc]
				if !ok {
					return nil, fmt.Errorf("'%s' 不支持的函数", fc)
				}
				row.Data[i] = f(tab, idx)
			}
			continue
		}

		// Plain query: one output row per input row.
		for _, record := range rows {
			row := tempTable.Append()
			for i, idx := range titleIdx {
				if value, ok := derefCell(record.Data[idx]); ok {
					row.Data[i] = value
				}
			}
		}
	}

	// 4a. Optional fill. The call reads two names, so reject a clause that
	// asks for a fill without providing both instead of indexing past the end.
	if fill := stmt.FillBy; fill != nil && (len(fill.Name) == 2 || len(fill.Vals) > 0) {
		if len(fill.Name) < 2 {
			return nil, errors.New("fill 需要两个列名")
		}
		tempTable, err = tempTable.FillBy(fill.Name[0], fill.Name[1], fill.Vals)
		if err != nil {
			return nil, err
		}
	}

	// 4b. Optional ordering of the temporary table.
	if stmt.OrderBy != nil {
		if err = tempTable.OrderBy(stmt.OrderBy.Name, stmt.OrderBy.DESC); err != nil {
			return nil, err
		}
	}

	return tempTable, nil
}

// derefCell returns the value behind a table cell holding a *float64 or a
// *string. ok is false for nil cells, nil pointers and any other cell type.
func derefCell(cell any) (value any, ok bool) {
	switch p := cell.(type) {
	case *float64:
		if p != nil {
			return *p, true
		}
	case *string:
		if p != nil {
			return *p, true
		}
	}
	return nil, false
}

// Where filters t by a single condition and returns the table produced by the
// matching memtable filter method.
//
// Columns of type memtable.TypeNumber accept a comparison against a number,
// a comparison against "<column> <math operator> <number>", IN and IS. All
// other columns accept a comparison against a string, IN, LIKE and IS.
//
// A nil table and an error are returned when t or cond is nil, when a column
// does not exist, when an operator is unknown, when an operand is missing, or
// when the kind of condition is not supported for the column type.
func Where(t *memtable.Table, cond *Condition) (table *memtable.Table, err error) {
	if t == nil || cond == nil {
		return nil, errors.New("参数错误")
	}

	idx := t.ColumnIndex(cond.Name)
	if idx < 0 {
		return nil, fmt.Errorf("'%s' 列名不存在", cond.Name)
	}
	rhs := cond.CondRHS

	if t.ColumnTypes()[idx] == memtable.TypeNumber {
		switch {
		case rhs.Compare != nil:
			cmp := rhs.Compare
			f := gNumFunc[cmp.Operator]
			if f == nil {
				return nil, fmt.Errorf("'%s'操作符不支持", cmp.Operator)
			}
			if cmp.Number != nil {
				table = t.WhereNum(idx, *cmp.Number, f)
				break
			}

			// No plain number: the right-hand side must be a math expression
			// over a second column.
			mt := cmp.Math
			if mt == nil {
				return nil, fmt.Errorf("'%s'参数错误", cmp.Operator)
			}
			fmath := gNumMath[mt.Operator]
			if fmath == nil {
				return nil, fmt.Errorf("'%s'操作符不支持", mt.Operator)
			}
			idx2 := t.ColumnIndex(mt.Name)
			if idx2 < 0 {
				// Report the column of the math expression, not the column
				// on the left-hand side of the comparison.
				return nil, fmt.Errorf("'%s' (数学函数)列名不存在", mt.Name)
			}
			table = t.WhereMath(idx, idx2, mt.Number, f, fmath)
		case rhs.In != nil:
			table = t.WhereNumArray(idx, rhs.In.NumArray, numIn)
		case rhs.Is != nil:
			table = t.WhereNull(idx, rhs.Is.Not)
		default:
			return nil, errors.New("操作符不支持")
		}
		return table, nil
	}

	switch {
	case rhs.Compare != nil:
		cmp := rhs.Compare
		if cmp.String == nil {
			return nil, fmt.Errorf("'%s'参数错误", cmp.Operator)
		}
		f := gStrFunc[cmp.Operator]
		if f == nil {
			return nil, fmt.Errorf("'%s'操作符不支持", cmp.Operator)
		}
		table = t.WhereStr(idx, *cmp.String, f)
	case rhs.In != nil:
		// Filter the input table and keep the result; calling the method on
		// the still-nil result and dropping its return value lost the filter.
		table = t.WhereStrArray(idx, rhs.In.StrArray, strIn)
	case rhs.Like != nil:
		// NOTE(review): the pattern is read from rhs.Compare, but Compare is
		// always nil in this branch because the Compare case is tested first,
		// so LIKE used to panic on a nil pointer. The field of the Like node
		// that holds the pattern is not visible in this file; read the
		// pattern from it here. Until then the query fails with an error.
		if rhs.Compare == nil || rhs.Compare.String == nil {
			return nil, errors.New("LIKE 参数错误")
		}
		table = t.WhereStr(idx, *rhs.Compare.String, like)
	case rhs.Is != nil:
		table = t.WhereNull(idx, rhs.Is.Not)
	default:
		// Same answer as the numeric branch instead of a silent (nil, nil).
		return nil, errors.New("操作符不支持")
	}
	return table, nil
}

// numAdd returns x + y.
func numAdd(x, y float64) float64 {
	sum := x + y
	return sum
}

// numSubt returns x - y.
func numSubt(x, y float64) float64 {
	diff := x - y
	return diff
}

// numMult returns x * y.
func numMult(x, y float64) float64 {
	product := x * y
	return product
}

// numDivi returns x / y, or NaN when y is zero.
func numDivi(x, y float64) float64 {
	if y != 0 {
		return x / y
	}
	return math.NaN()
}

// numModu returns the integer remainder of x divided by y. Both operands are
// truncated to int before the remainder is taken.
//
// It returns NaN when y truncates to zero, matching numDivi, because an
// integer remainder by zero would otherwise panic at run time.
func numModu(x, y float64) float64 {
	divisor := int(y)
	if divisor == 0 {
		return math.NaN()
	}
	return float64(int(x) % divisor)
}

// numIn reports whether x equals at least one element of y.
func numIn(x float64, y []float64) bool {
	for i := range y {
		if y[i] == x {
			return true
		}
	}
	return false
}

// numEQ reports whether x and y are equal.
func numEQ(x, y float64) bool {
	return y == x
}

// numGT reports whether x is greater than y.
func numGT(x, y float64) bool {
	return y < x
}
// numGE reports whether x is greater than or equal to y.
func numGE(x, y float64) bool {
	return y <= x
}
// numNE reports whether x and y differ.
func numNE(x, y float64) bool {
	return y != x
}

// numLT reports whether x is less than y.
func numLT(x, y float64) bool {
	return y > x
}

// numLE reports whether x is less than or equal to y.
func numLE(x, y float64) bool {
	return y >= x
}

func like(x string, y string) bool {
	if strings.HasPrefix(y, "%") {
		if strings.HasSuffix(y, "%") {
			return strings.Contains(x, y[1:len(y)-1])
		}
		return strings.HasPrefix(x, y[1:])
	}
	if strings.HasSuffix(y, "%") {
		return strings.HasSuffix(x, y[:len(y)-1])
	}
	return x == y
}

// strIn reports whether x equals at least one element of y.
func strIn(x string, y []string) bool {
	for i := range y {
		if y[i] == x {
			return true
		}
	}
	return false
}

// strEQ reports whether x and y are equal.
func strEQ(x, y string) bool {
	return y == x
}

// strGT reports whether x sorts after y.
func strGT(x, y string) bool {
	return y < x
}
// strGE reports whether x sorts after y or equals it.
func strGE(x, y string) bool {
	return y <= x
}
// strNE reports whether x and y differ.
func strNE(x, y string) bool {
	return y != x
}

// strLT reports whether x sorts before y.
func strLT(x, y string) bool {
	return y > x
}

// strLE reports whether x sorts before y or equals it.
func strLE(x, y string) bool {
	return y >= x
}
